{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "import tensorflow as tf\n",
    "from sklearn.cross_validation import train_test_split\n",
    "import time\n",
    "import random\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every sub-folder of data/ becomes one class; each file inside it is one document.\n",
    "trainset = sklearn.datasets.load_files(container_path='data', encoding='UTF-8')\n",
    "# separate_dataset comes from utils; presumably it cleans/splits the raw texts -- verify there.\n",
    "trainset.data, trainset.target = separate_dataset(trainset, 1.0)\n",
    "for summary in (trainset.target_names, len(trainset.data), len(trainset.target)):\n",
    "    print(summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flat list of whitespace-separated tokens across every document.\n",
    "concat = [token for doc in trainset.data for token in doc.split()]\n",
    "vocabulary_size = len(set(concat))  # number of distinct tokens\n",
    "data, count, dictionary, rev_dictionary = build_dataset(concat, vocabulary_size)\n",
    "print('vocab from size: %d' % vocabulary_size)\n",
    "# NOTE(review): the [4:10] slice presumably skips special tokens at the head of `count` -- confirm in utils.build_dataset.\n",
    "print('Most common words', count[4:10])\n",
    "sample_ids = data[:10]\n",
    "print('Sample data', sample_ids, [rev_dictionary[idx] for idx in sample_ids])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ids of the special tokens; presumably inserted into the vocabulary by utils.build_dataset.\n",
    "GO, PAD, EOS, UNK = (dictionary[token] for token in ('GO', 'PAD', 'EOS', 'UNK'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by the model, the index conversion and the training loop.\n",
    "embedding_size, maxlen = 128, 50  # embedding width; token ids per document\n",
    "kernel_size, num_filters = 3, 150  # 1-D conv width; channels per conv\n",
    "batch_size = 32\n",
    "dimension_output = len(trainset.target_names)  # one logit per class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"Deep-pyramid (DPCNN-style) CNN text classifier built in TF1 graph mode.\n",
    "\n",
    "    Graph: embedding -> 'valid' region conv -> two conv/BN layers with a\n",
    "    residual shortcut -> repeated pool-by-2 residual blocks until the\n",
    "    sequence axis has length < 2 -> 1x1 conv to per-class logits.\n",
    "\n",
    "    Args:\n",
    "        maxlen: number of token ids per example fed to `self.X`.\n",
    "        dimension_output: number of classes (width of `self.logits`).\n",
    "        vocab_size: number of rows in the embedding matrix.\n",
    "        embedding_size: width of each token embedding.\n",
    "        kernel_size: width of every intermediate 1-D convolution.\n",
    "        num_filters: channels of every intermediate 1-D convolution.\n",
    "        learning_rate: Adam learning rate.\n",
    "\n",
    "    Exposes placeholders `X` (token ids) and `Y` (class indices) plus\n",
    "    `logits`, `cost`, `optimizer` and `accuracy`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 maxlen,\n",
    "                 dimension_output,\n",
    "                 vocab_size,\n",
    "                 embedding_size,\n",
    "                 kernel_size,\n",
    "                 num_filters,\n",
    "                 learning_rate):\n",
    "        self.X = tf.placeholder(tf.int32, [None, maxlen])  # token ids\n",
    "        self.Y = tf.placeholder(tf.int32, [None])  # class index per example\n",
    "        embeddings = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1, 1))\n",
    "        embedded = tf.nn.embedding_lookup(embeddings, self.X)\n",
    "\n",
    "        def _conv(x, padding='same'):\n",
    "            # Shared settings for every intermediate convolution.\n",
    "            return tf.layers.conv1d(\n",
    "                x,\n",
    "                num_filters,\n",
    "                kernel_size=kernel_size,\n",
    "                strides=1,\n",
    "                padding=padding,\n",
    "            )\n",
    "\n",
    "        # NOTE(review): tf.layers.batch_normalization defaults to training=False,\n",
    "        # so every BN layer here normalizes with moving statistics that are never\n",
    "        # updated. Real batch-norm training would need a training flag plus the\n",
    "        # UPDATE_OPS control dependency on the optimizer.\n",
    "        first_region = _conv(embedded, padding='valid')\n",
    "        forward = tf.nn.relu(first_region)\n",
    "        forward = tf.layers.batch_normalization(_conv(forward))\n",
    "        # Activate the conv/BN output (not `first_region`) so both convolutions\n",
    "        # actually feed the logits and receive gradients.\n",
    "        forward = tf.nn.relu(forward)\n",
    "        forward = tf.layers.batch_normalization(_conv(forward))\n",
    "        forward = tf.nn.relu(forward)\n",
    "        forward = forward + first_region  # residual shortcut\n",
    "\n",
    "        def _block(x):\n",
    "            # Right-pad by one step, then max-pool (size 3, stride 2) to roughly\n",
    "            # halve the sequence length; `px` is also the residual shortcut.\n",
    "            x = tf.pad(x, paddings=[[0, 0], [0, 1], [0, 0]])\n",
    "            px = tf.layers.max_pooling1d(x, 3, 2)\n",
    "            x = tf.nn.relu(px)\n",
    "            x = tf.layers.batch_normalization(_conv(x))\n",
    "            x = tf.nn.relu(x)\n",
    "            x = tf.layers.batch_normalization(_conv(x))\n",
    "            return x + px\n",
    "\n",
    "        # The sequence length is static (maxlen is fixed), so this loop unrolls\n",
    "        # at graph-build time.\n",
    "        while forward.get_shape().as_list()[1] >= 2:\n",
    "            forward = _block(forward)\n",
    "        # Sequence axis now has length at most 1; summing over it drops that axis.\n",
    "        self.logits = tf.reduce_sum(tf.layers.conv1d(\n",
    "            forward, dimension_output, kernel_size=1, strides=1, padding='SAME'\n",
    "        ), 1)\n",
    "        self.cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits=self.logits,\n",
    "            labels=self.Y))\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        correct_pred = tf.equal(tf.argmax(self.logits, 1, output_type=tf.int32), self.Y)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "# Graph-level seed: set after the reset (which creates a fresh, unseeded graph)\n",
    "# so variable initialisation is repeatable across kernel restarts.\n",
    "tf.set_random_seed(42)\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(maxlen, dimension_output, len(dictionary), embedding_size,\n",
    "              kernel_size, num_filters, learning_rate=1e-3)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token-id matrix with maxlen columns (rows are fed straight into model.X).\n",
    "vectors = str_idx(trainset.data, dictionary, maxlen)\n",
    "# Fixed random_state so the 80/20 split is reproducible across kernel restarts.\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(\n",
    "    vectors, trainset.target, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Early stopping: quit after EARLY_STOPPING consecutive epochs without a new best test accuracy.\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 3, 0, 0, 0\n",
    "\n",
    "\n",
    "def run_minibatches(X, Y, desc, train):\n",
    "    \"\"\"One pass over (X, Y) in `batch_size` chunks; returns (mean loss, mean accuracy).\n",
    "\n",
    "    Per-batch metrics are weighted by batch size, so a short final batch cannot skew\n",
    "    the averages (dividing by len(X) / batch_size could push accuracy above 1.0).\n",
    "    When `train` is True the optimizer step also runs and a NaN loss aborts training.\n",
    "    \"\"\"\n",
    "    total_loss, total_acc = 0.0, 0.0\n",
    "    pbar = tqdm(range(0, len(X), batch_size), desc=desc)\n",
    "    for i in pbar:\n",
    "        batch_x = X[i : i + batch_size]\n",
    "        batch_y = Y[i : i + batch_size]\n",
    "        feed = {model.X: batch_x, model.Y: batch_y}\n",
    "        if train:\n",
    "            acc, cost, _ = sess.run(\n",
    "                [model.accuracy, model.cost, model.optimizer], feed_dict=feed\n",
    "            )\n",
    "            assert not np.isnan(cost)\n",
    "        else:\n",
    "            acc, cost = sess.run([model.accuracy, model.cost], feed_dict=feed)\n",
    "        # model.cost / model.accuracy are batch means; scale back to batch sums.\n",
    "        total_loss += cost * len(batch_x)\n",
    "        total_acc += acc * len(batch_x)\n",
    "        pbar.set_postfix(cost=cost, accuracy=acc)\n",
    "    n_examples = max(len(X), 1)\n",
    "    return total_loss / n_examples, total_acc / n_examples\n",
    "\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_loss, train_acc = run_minibatches(train_X, train_Y, 'train minibatch loop', train=True)\n",
    "    # NOTE(review): the test split doubles as the early-stopping validation set, so\n",
    "    # the final reported accuracy is optimistically biased.\n",
    "    test_loss, test_acc = run_minibatches(test_X, test_Y, 'test minibatch loop', train=False)\n",
    "\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "real_Y, predict_Y = [], []\n",
    "\n",
    "# Batched inference on the held-out split; collect gold labels and argmax predictions.\n",
    "for start in tqdm(range(0, len(test_X), batch_size), desc='validation minibatch loop'):\n",
    "    stop = min(start + batch_size, test_X.shape[0])\n",
    "    batch_x, batch_y = test_X[start:stop], test_Y[start:stop]\n",
    "    batch_logits = sess.run(model.logits, feed_dict={model.X: batch_x, model.Y: batch_y})\n",
    "    predict_Y.extend(np.argmax(batch_logits, 1).tolist())\n",
    "    real_Y.extend(batch_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "report = metrics.classification_report(real_Y, predict_Y, target_names=trainset.target_names)\n",
    "print(report)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
